# %%
import pandas as pd
import numpy as np
from openai import OpenAI
from tqdm import tqdm

import start

prompts_file = start.MAIN_DIR + "prompts_characterization.xlsx"
# sheet_name=None loads every sheet, returning a dict of {sheet name: DataFrame};
# each sheet is one prompt variant.
prompts = pd.read_excel(prompts_file, sheet_name=None)

text_file = start.MAIN_DIR + "data/clean/character_classifications_gold.xlsx"
text_df = pd.read_excel(text_file)
# Keep only rows in the "dev" split that actually have text. .copy() makes
# text_df an independent frame (not a view of the raw read), so later column
# assignments do not trigger pandas' SettingWithCopyWarning.
text_df = text_df[text_df.text.notnull() & (text_df.set == "dev")].copy()

client = OpenAI(api_key=start.OPENAI_API_KEY)
# Replace with fine-tuned model key
MODEL = ""
# %%
# Specify the desired prompt name, or leave empty for all
SHEET_NAME = "Zero Shot 1"

OUTPUT_FILE = start.MAIN_DIR + "data/clean/gpt_classifications_characters_dev.xlsx"
OUTPUT_COLUMNS = [
    "unique_id",
    "text",
    "prompt",
    "response",
    "character_gold",
    "agreement",
]

# For each selected prompt sheet: send every dev text to MODEL, compare the
# answer with the gold label, and write one results sheet per prompt.
for sheet_name, sheet_data in prompts.items():
    # Sheets whose name contains "Sheet" are skipped (presumably Excel's
    # default "Sheet1"-style sheets, not prompts).
    if "Sheet" in sheet_name:
        continue
    # An empty SHEET_NAME means "run every prompt sheet".
    if SHEET_NAME and sheet_name != SHEET_NAME:
        continue
    print(sheet_name)

    # Each row of the prompt sheet is one chat message; together they form
    # the fixed prefix sent before every text.
    prompt = sheet_data[["role", "content"]].to_dict("records")

    responses = []
    for text in tqdm(text_df.text):
        messages = prompt + [{"role": "user", "content": text}]

        response = client.chat.completions.create(
            model=MODEL,
            messages=messages,
            temperature=0.00,
        )

        # message.content can be None (e.g. a refusal). Strip whitespace so an
        # answer such as "label\n" still matches the gold label exactly.
        content = response.choices[0].message.content
        responses.append((content or "").strip())

    # Build a fresh frame instead of mutating text_df, so every prompt
    # iteration starts from the same clean input rows.
    result_df = text_df.assign(prompt=sheet_name, response=responses)
    result_df["agreement"] = np.where(
        result_df.response == result_df.character_gold, 1, 0
    )
    result_df = result_df[OUTPUT_COLUMNS]

    # Append to the existing workbook, replacing this prompt's sheet if it is
    # already there. mode="a" cannot create a missing file, so on the first
    # run fall back to writing a new workbook.
    try:
        writer = pd.ExcelWriter(
            OUTPUT_FILE,
            engine="openpyxl",
            mode="a",
            if_sheet_exists="replace",
        )
    except FileNotFoundError:
        writer = pd.ExcelWriter(OUTPUT_FILE, engine="openpyxl", mode="w")
    with writer:
        # "ft" suffix presumably marks results from the fine-tuned MODEL.
        result_df.to_excel(writer, sheet_name=sheet_name + "ft", index=False)
# %%
